# Copyright (c) Facebook, Inc. and its affiliates.

import copy
import io
import logging
import numpy as np
from typing import List
import onnx
import torch
from caffe2.proto import caffe2_pb2
from caffe2.python import core
from caffe2.python.onnx.backend import Caffe2Backend
from tabulate import tabulate
from termcolor import colored
from torch.onnx import OperatorExportTypes

from .shared import (
    ScopedWS,
    construct_init_net_from_params,
    fuse_alias_placeholder,
    fuse_copy_between_cpu_and_gpu,
    get_params_from_init_net,
    group_norm_replace_aten_with_caffe2,
    infer_device_type,
    remove_dead_end_ops,
    remove_reshape_for_fc,
    save_graph,
)

# Module-level logger used by the export helpers below.
logger = logging.getLogger(__name__)


def export_onnx_model(model, inputs):
    """
    Trace and export a model to onnx format.

    Args:
        model (nn.Module): the model to export. Every submodule must already be in
            eval mode; an AssertionError is raised otherwise.
        inputs (tuple[args]): the model will be called by `model(*inputs)`

    Returns:
        an onnx model. If the installed onnx package still provides
        ``onnx.optimizer``, the "fuse_bn_into_conv" pass has been applied to it;
        otherwise the unoptimized model is returned and a warning is logged.
    """
    assert isinstance(model, torch.nn.Module)

    # make sure all modules are in eval mode, onnx may change the training state
    # of the module if the states are not consistent
    def _check_eval(module):
        assert not module.training, "{} is in training mode; call model.eval() before export".format(
            type(module).__name__
        )

    model.apply(_check_eval)

    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                # verbose=True,  # NOTE: uncomment this for debugging
                # export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

    # Apply ONNX's Optimization.
    # `onnx.optimizer` was removed from newer onnx releases (it moved to the
    # standalone `onnxoptimizer` package); accessing it unconditionally would
    # raise AttributeError there, so skip the optional pass instead of failing.
    try:
        from onnx import optimizer as onnx_optimizer
    except ImportError:
        logger.warning(
            "onnx.optimizer is not available in the installed onnx package; "
            "skipping the 'fuse_bn_into_conv' optimization pass."
        )
        return onnx_model

    all_passes = onnx_optimizer.get_available_passes()
    passes = ["fuse_bn_into_conv"]
    assert all(p in all_passes for p in passes)
    onnx_model = onnx_optimizer.optimize(onnx_model, passes)
    return onnx_model


def _op_stats(net_def):
    type_count = {}
    for t in [op.type for op in net_def.op]:
        type_count[t] = type_count.get(t, 0) + 1
    type_count_list = sorted(type_count.items(), key=lambda kv: kv[0])  # alphabet
    type_count_list = sorted(type_count_list, key=lambda kv: -kv[1])  # count
    return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list)


def _assign_device_option(
    predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
):
    """
    ONNX exported network doesn't have concept of device, assign necessary
    device option for each op in order to make it runable on GPU runtime.

    Both nets are modified in place. Blob device types are first inferred for
    ``predict_net``, seeded with the devices of ``tensor_inputs`` (matched
    positionally against the leading entries of ``predict_net.external_input``).
    They are then inferred for ``init_net``, seeded with the device each of its
    external outputs has in ``predict_net``. Ops whose blobs are on "cuda", as
    well as every CopyCPUToGPU / CopyGPUToCPU op, get a CUDA:0 device option;
    all other ops are left unchanged.

    Args:
        predict_net: the Caffe2 predict net to annotate.
        init_net: the Caffe2 init net whose external outputs feed ``predict_net``.
        tensor_inputs: example input tensors; each must be on CPU or on cuda:0.
    """

    def _get_device_type(torch_tensor):
        device = torch_tensor.device
        assert device.type in ["cpu", "cuda"], "Unsupported input device: {}".format(device)
        # CPU tensors report `index is None`, so only GPU inputs are required to
        # be on device 0 (the only GPU ops are assigned to below). Checking the
        # index unconditionally rejected any CPU input in a mixed CPU/GPU model.
        assert device.type == "cpu" or device.index == 0, "Only cuda:0 is supported, got {}".format(
            device
        )
        return device.type

    def _assign_op_device_option(net_proto, net_ssa, blob_device_types):
        for op, ssa_i in zip(net_proto.op, net_ssa):
            if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
                op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
            else:
                # ssa_i[0] / ssa_i[1] are the (name, version) pairs of the op's
                # inputs / outputs; every one of them must agree on a device.
                devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
                assert all(d == devices[0] for d in devices), (
                    "Op {} mixes blobs on different devices: {}".format(op.type, devices)
                )
                # An op with no inputs and no outputs has no device to follow.
                if devices and devices[0] == "cuda":
                    op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))

    # update ops in predict_net
    predict_net_input_device_types = {
        (name, 0): _get_device_type(tensor)
        for name, tensor in zip(predict_net.external_input, tensor_inputs)
    }
    predict_net_device_types = infer_device_type(
        predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
    )
    predict_net_ssa, _ = core.get_ssa(predict_net)
    _assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types)

    # update ops in init_net: its outputs must live where predict_net consumes them
    init_net_ssa, versions = core.get_ssa(init_net)
    init_net_output_device_types = {
        (name, versions[name]): predict_net_device_types[(name, 0)]
        for name in init_net.external_output
    }
    init_net_device_types = infer_device_type(
        init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
    )
    _assign_op_device_option(init_net, init_net_ssa, init_net_device_types)


def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
    """
    Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.

    The export operates on a deep copy of ``model``.

    Args:
        model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py.
            It must provide an ``encode_additional_info(predict_net, init_net)`` method.
        tensor_inputs: a list of tensors that caffe2 model takes as input; the model
            is traced as ``model(tensor_inputs)``.

    Returns:
        tuple: ``(predict_net, init_net)``, the optimized Caffe2 nets.
    """
    c2_model = copy.deepcopy(model)
    assert isinstance(c2_model, torch.nn.Module)
    assert hasattr(c2_model, "encode_additional_info")

    # Export via ONNX
    start_msg = "Exporting a {} model via ONNX ...".format(type(c2_model).__name__)
    logger.info(start_msg + " Some warnings from ONNX are expected and are usually not to worry about.")
    onnx_model = export_onnx_model(c2_model, (tensor_inputs,))

    # Convert ONNX model to Caffe2 protobuf and show the raw result
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
    op_rows = [[op.type, op.input, op.output] for op in predict_net.op]
    rendered = tabulate(op_rows, headers=["type", "input", "output"], tablefmt="pipe")
    logger.info(
        "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(rendered, "cyan")
    )

    # Protobuf-level graph optimizations
    fuse_alias_placeholder(predict_net, init_net)
    has_non_cpu_input = any(t.device.type != "cpu" for t in tensor_inputs)
    if has_non_cpu_input:
        # Only needed when some input lives on a GPU.
        fuse_copy_between_cpu_and_gpu(predict_net)
        remove_dead_end_ops(init_net)
        _assign_device_option(predict_net, init_net, tensor_inputs)

    # Rebuild init_net from its parameters after the FC-reshape rewrite.
    params, device_options = get_params_from_init_net(init_net)
    predict_net, params = remove_reshape_for_fc(predict_net, params)
    init_net = construct_init_net_from_params(params, device_options)
    group_norm_replace_aten_with_caffe2(predict_net)

    # Record necessary information for running the pb model in Detectron2 system.
    c2_model.encode_additional_info(predict_net, init_net)

    logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net)))
    logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net)))

    return predict_net, init_net


def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
    """
    Run the caffe2 model on given inputs and save a drawing of its graph,
    annotated with the blob shapes observed during the run.

    Args:
        predict_net, init_net: the caffe2 model.
        tensor_inputs: a list of tensors that caffe2 model takes as input. They are
            fed, in order, to the entries of ``predict_net.external_input`` that do
            not already exist after running ``init_net``.
        graph_save_path: path for saving graph of exported model. It is written
            once before the run (no shapes) and again afterwards (with shapes).

    Returns:
        dict: name -> fetched value for every blob in the temporary workspace
        after the run.
    """
    logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
    save_graph(predict_net, graph_save_path, op_only=False)

    # Run the exported Caffe2 net
    logger.info("Running ONNX exported model ...")
    with ScopedWS("__ws_tmp__", True) as ws:
        ws.RunNetOnce(init_net)
        created_by_init = set(ws.Blobs())
        input_names = [name for name in predict_net.external_input if name not in created_by_init]
        for input_name, input_tensor in zip(input_names, tensor_inputs):
            ws.FeedBlob(input_name, input_tensor)

        # Best effort: even if the run fails, still save the graph with
        # whatever blobs were produced.
        try:
            ws.RunNetOnce(predict_net)
        except RuntimeError as e:
            logger.warning("Encountered RuntimeError: \n{}".format(str(e)))

        ws_blobs = {}
        for blob_name in ws.Blobs():
            ws_blobs[blob_name] = ws.FetchBlob(blob_name)
        blob_sizes = {
            blob_name: value.shape
            for blob_name, value in ws_blobs.items()
            if isinstance(value, np.ndarray)
        }

        logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
        save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)

        return ws_blobs
